import torch
from train import CNN
import torchvision
# Export the trained CNN to a TorchScript file by tracing it with an example
# input, so the saved module can be loaded without the Python class definition.
model = CNN()

# map_location="cpu" makes the checkpoint loadable on machines without a GPU
# even if it was saved from a CUDA run. The model and the example input below
# are created on the CPU, so this matches the device everything else lives on.
stat_dict = torch.load("cnn16.pt", map_location="cpu")
model.load_state_dict(stat_dict)

# Switch to evaluation mode before tracing so the recorded graph reflects
# inference behaviour rather than training behaviour.
model.eval()

# Tracing records the operations executed for this one example input; only the
# tensor's shape/dtype matter here, not its random values.
# NOTE(review): assumes the network expects a (1, 1, 16, 16) input
# (presumably batch, channel, height, width) -- confirm against train.CNN.
example = torch.rand((1, 1, 16, 16))
traced_script = torch.jit.trace(model, example)
traced_script.save("cnn_model16.pt")
